"""生成 n 个随机数，使其和固定"""

import random
import numpy as np


def fixed_rvs(n, total, limits: tuple, func, *args, **kwargs):
    """
    生成 n 个和为 total 的随机数（每个随机数由 func(*args, **kwargs) 产生）

    == 参数 ==
    n      : 生成的随机数个数
    total  : 随机数之和
    limits : 每个随机数上下限（以二元组的形式传入）
    func   : 单个随机数生成函数
    args   : func 接收的可变参数
    kwargs : func 接收的可变关键字参数

    == 思路 ==

    每个随机数应满足：
    1. 在范围 limits 内
    2. 保证后续的随机数可以正常生成
       比如生成 3 个和为 10 的随机数，而随机数生成函数只能保证生成范围在 1~5 的随机数
       如前 2 个生成的随机数为 [1, 2]，则第 3 个随机数无论如何生成都无法确保和为 10，
       故需在每个利用 func 生成随机数后进行上下限确认，上下限应为：
       下限*：total - (待生成的随机数个数 - 1) * 基本上限（limits[1]）
       上限*：total - (待生成的随机数个数 - 1) * 基本下限（limits[0]）
       如在此区间则加入返回序列中

    == 注意 ==

    生成函数的取值分布应覆盖住输入的上下限，否则可能导致生成缓慢或无限循环

    """

    res = []
    lower, upper = limits

    for i in range(n - 1):
        lower_ = max(lower, total - (n - i - 1) * upper)
        upper_ = min(upper, total - (n - i - 1) * lower)
        while True:
            rn = func(*args, **kwargs)
            if lower_ <= rn <= upper_:
                res.append(rn)
                total -= rn
                break

    res.append(total)

    return res


if __name__ == "__main__":
    # 测试不同随机数生成函数

    n = 10000  # 测试例数
    m = 10  # 每组生成数
    t = 600  # 每组总和
    lim = (10, 90)  # 最值上下限

    genfun_params = [
        (random.randint, 0, 100),
        (lambda *x: int(random.gauss(*x)), t / m, 1),
        (lambda *x: int(random.gauss(*x)), t / m, 10),
        (lambda *x: int(random.gauss(*x)), t / m, 20),
        (lambda *x: int(random.triangular(*x)), 0, 100, 60),
        (lambda *x: int(random.triangular(*x)), 0, 100, 40),
    ]

    for i, p in enumerate(genfun_params, 1):
        print(f'#{i} | {p[0].__name__}{p[1:]}')
        rns = np.array([fixed_rvs(m, t, lim, *p) for _ in range(n)])

        # 合规检查
        valid = all([sum(rn) == t
                     and lim[0] <= min(rn)
                     and max(rn) <= lim[1] for rn in rns])
        print(f'\tvalid: {valid}')

        # 平均标准差
        std = np.mean([np.std(rn) for rn in rns])
        print(f'\tstd: {std:.4f}')

        # 抽样
        dat = random.choice(rns)
        print(f'\tsample: {", ".join(map(str, dat))}')

# #1 | randint(0, 100)
#    valid: True
#    std: 24.3537
#    sample: 74, 74, 27, 69, 77, 55, 87, 14, 66, 57
# #2 | <lambda>(60.0, 1)
#    valid: True
#    std: 1.8978
#    sample: 60, 59, 62, 60, 60, 59, 60, 60, 61, 59
# #3 | <lambda>(60.0, 10)
#    valid: True
#    std: 11.5528
#    sample: 53, 52, 70, 46, 60, 49, 60, 63, 70, 77
# #4 | <lambda>(60.0, 20)
#    valid: True
#    std: 18.2265
#    sample: 50, 69, 46, 63, 81, 70, 59, 68, 66, 28
# #5 | <lambda>(0, 100, 60)
#    valid: True
#    std: 20.6395
#    sample: 67, 29, 86, 52, 48, 45, 18, 87, 78, 90
# #6 | <lambda>(0, 100, 40)
#    valid: True
#    std: 23.2748
#    sample: 69, 29, 48, 81, 11, 36, 72, 78, 89, 87
